import io
import tempfile
import zipfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional

import boto3
import PyPDF2
from boto3.session import Config
from botocore.client import UNSIGNED
from datasets import Dataset
from tqdm import tqdm


def process_pdf(
    pdf_bytes: bytes, pdf_filename: str, url: Optional[str] = None
) -> List[Dict[str, Any]]:
    """
    Extract the text of every page of a single in-memory PDF.

    Extraction is all-or-nothing: if the file cannot be parsed, is
    encrypted, or any page fails to yield text, an empty list is returned
    rather than a partial result.

    Args:
        pdf_bytes: Raw contents of the PDF file.
        pdf_filename: Name stored in the "pdf_file" field of every row.
        url: Optional source location stored in the "url" field of every
            row; None when the caller does not know it.

    Returns:
        One dict per page with the keys "url", "pdf_file", "page_number"
        (0-based) and "text", or an empty list if the PDF was skipped.
    """
    try:
        reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes), strict=False)
    except PyPDF2.errors.PdfReadError:
        return []  # Corrupted or unparsable PDF

    # No decryption is attempted: encrypted PDFs are skipped outright.
    if reader.is_encrypted:
        return []

    pages = []
    for page_num, page in enumerate(reader.pages):
        try:
            text = page.extract_text()
        except Exception:
            # One unreadable page discards the whole document so that
            # callers never receive a PDF with missing pages.
            return []
        pages.append(
            {
                "url": url,
                "pdf_file": pdf_filename,
                "page_number": page_num,
                "text": text,
            }
        )

    return pages


def download_s3_zip(bucket: str, key: str) -> bytes:
    """
    Download a file from S3 using unsigned (anonymous) requests.

    Args:
        bucket: Name of the S3 bucket.
        key: Key of the object inside the bucket.

    Returns:
        The complete object body, read into memory.
    """
    # Disable request signing so that no AWS credentials are required;
    # without this the client signs requests with whatever credentials it
    # finds and fails when there are none.
    s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))
    response = s3.get_object(Bucket=bucket, Key=key)
    return response["Body"].read()


def process_zip_contents(zip_bytes: bytes) -> List[Dict[str, Any]]:
    """
    Extract page text from every PDF inside an in-memory zip archive.

    Archive members whose name ends in ".pdf" (case-insensitive) are read
    into memory and handed to process_pdf on a pool of up to 20 worker
    threads. A PDF whose processing raises is reported on stdout and
    skipped so that a single bad file does not abort the whole archive.

    Args:
        zip_bytes: Raw contents of the zip archive.

    Returns:
        The page rows of all PDFs, concatenated in the order in which the
        PDFs finished processing (not in archive order).
    """
    all_pages = []

    # Read every PDF up front so the archive can be closed before the
    # worker threads start.
    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as z:
        pdf_files = [(z.read(f), f) for f in z.namelist() if f.lower().endswith(".pdf")]

    with ThreadPoolExecutor(max_workers=20) as executor:
        future_to_pdf = {
            executor.submit(process_pdf, pdf_bytes, pdf_filename): pdf_filename
            for pdf_bytes, pdf_filename in pdf_files
        }

        for future in tqdm(
            as_completed(future_to_pdf),
            total=len(future_to_pdf),
            desc="Processing PDFs",
        ):
            pdf_filename = future_to_pdf[future]
            try:
                pages = future.result()
            except Exception as e:
                print(f"Error processing {pdf_filename}: {e}")
                continue
            all_pages.extend(pages)

    return all_pages


def download_digital_corpora(dataset: Dataset, s3_urls_column: str) -> Dataset:
    """
    Download the zip archive behind every S3 URL in a dataset column.

    The archives are only downloaded here, not unpacked. A URL whose
    download fails is reported on stdout and skipped.

    Args:
        dataset: Dataset holding the S3 URLs.
        s3_urls_column: Name of the column containing URLs in the format
            's3://bucket/key'.

    Returns:
        Dataset: Huggingface Dataset with a single "zip_bytes" column
        holding one row per successfully downloaded archive. When nothing
        could be downloaded, a Dataset built from an empty list is
        returned instead.
    """
    all_zip_bytes = []
    for url in dataset[s3_urls_column]:
        # 's3://bucket/some/key' splits into ['s3:', '', 'bucket', 'some', 'key'].
        parts = url.split("/")
        bucket = parts[2]
        key = "/".join(parts[3:])
        try:
            zip_bytes = download_s3_zip(bucket, key)
        except Exception as e:
            print(f"Error downloading {url}: {e}")
            continue
        all_zip_bytes.append(zip_bytes)

    print(f"Downloaded {len(all_zip_bytes)} zip files")
    if not all_zip_bytes:
        return Dataset.from_list([])
    return Dataset.from_dict({"zip_bytes": all_zip_bytes})


def dataset_process_zip_bytes(dataset: Dataset, zip_bytes_column: str) -> Dataset:
    """
    Turn a dataset of zip archives into a dataset of PDF page rows.

    Archives are handled one after another; the PDFs inside each archive
    are processed by process_zip_contents.

    Args:
        dataset: Dataset holding raw zip archives.
        zip_bytes_column: Name of the column containing the archive bytes.

    Returns:
        Dataset: Huggingface Dataset built from the page rows of all
        archives.
    """
    all_pages = []
    for zip_bytes in dataset[zip_bytes_column]:
        all_pages.extend(process_zip_contents(zip_bytes))
    return Dataset.from_list(all_pages)


# Example usage:
if __name__ == "__main__":
    s3_urls = [
        "s3://digitalcorpora/corpora/files/CC-MAIN-2021-31-PDF-UNTRUNCATED/zipfiles/0000-0999/0000.zip"
    ]

    # Both pipeline stages take a Dataset plus a column name, so the URL
    # list is wrapped in a one-column Dataset first.
    url_dataset = Dataset.from_dict({"s3_url": s3_urls})
    zip_dataset = download_digital_corpora(url_dataset, "s3_url")
    # When every download failed the result is an empty Dataset that may
    # not have a "zip_bytes" column, so stop before the extraction stage.
    if len(zip_dataset) == 0:
        raise SystemExit("No zip files could be downloaded")

    dataset = dataset_process_zip_bytes(zip_dataset, "zip_bytes")
    print(f"Processed {len(dataset)} pages")
